import ast
import logging

from flask import Blueprint, request

from algo.compute_stats import norm_test
from algo.feature import zscore_outlier, kmeans_outlier, scale_data, time_series_decompose, data_smoothing, timeseries_shift
from algo.feature.fill_nulldata import mean_interpolate, most_frequent_interpolate, median_interpolate, \
    liner_interpolate, mice_fill, fillsDic, period_fill, checkNull, not_null_col
from model.task import update_task
from utils.database_util import read_gp, save_gp
from utils.format_util import py_to_java, get_err_msg, err_formatter
from utils.timeseries_util import adf_tests

# Configure root logging once at import time; all handlers below log through it.
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")

# Blueprint grouping the feature-engineering endpoints; registered by the app elsewhere.
feature = Blueprint('feature', __name__)

@feature.route('/smoothing', methods=['POST'])
def smoothing():
    """Run data smoothing on the source table and save the result.

    Form params: source (table), target (output table), taskId, plus the
    method parameters consumed by data_smoothing.run.
    Returns the smoothing results via py_to_java on success, or an error
    payload from err_formatter / get_err_msg on failure.
    """
    try:
        params = request.form.to_dict()
        logging.info(params)
        source = params.get("source")
        target = params.get("target")
        taskId = params.get("taskId")

        if "." not in source:
            source = "dataset." + source
        # NOTE(review): table name is interpolated straight into the SQL text and
        # comes from the request — confirm it is validated upstream.
        sql = "select * from {} order by _record_id_".format(source)
        df = read_gp(sql)

        df, results = data_smoothing.run(df, params)
    except Exception as e:
        return err_formatter(e)
    saved = 0
    updated = 0
    if df is not None:
        try:
            saved = save_gp(target, df)
            updated = update_task(taskId, df, target, True)
        except Exception as e:
            logging.error("error occurs in saving results")
            logging.error(e)
            return err_formatter(e)
    # BUG FIX: the success/failure decision used to sit in a `finally` block,
    # whose `return` unconditionally overrode the `return err_formatter(e)`
    # above and would swallow any in-flight exception. It also left the
    # df-is-None path with no return at all (implicit None response).
    if saved and updated:
        return py_to_java(str(results))
    return get_err_msg()


@feature.route('/standardization', methods=['POST'])
def standardization():
    """Scale/standardize the source table's columns and save the result.

    Form params: source, target, taskId plus scaling options for
    scale_data.run. Returns the scaling results via py_to_java on success.
    """
    # BUG FIX: saved/updated were previously unbound when df is None,
    # raising NameError at the check below instead of returning an error.
    saved = 0
    updated = 0
    try:
        params = request.form.to_dict()
        source = params.get("source")
        target = params.get("target")
        taskId = params.get("taskId")

        if "." not in source:
            source = "dataset." + source
        # NOTE(review): request-supplied table name formatted into SQL — confirm validated upstream.
        sql = "select * from {} order by _record_id_".format(source)
        df = read_gp(sql)

        df, results = scale_data.run(df, params)
    except Exception as e:
        return err_formatter(e)
    if df is not None:
        saved = save_gp(target, df)
        updated = update_task(taskId, df, target, True)
    if saved and updated:
        return py_to_java(str(results))
    return get_err_msg()


@feature.route('/anomaly_stat', methods=['POST'])
def anomaly_stat():
    """Detect outliers with the z-score method; save the flagged table.

    Form params: source, target, taskId plus options for zscore_outlier.run.
    The "data_vis" entry produced by the algorithm is passed to update_task
    but stripped from the HTTP response.
    """
    # BUG FIX: saved/updated were unbound when df is None (NameError below).
    saved = 0
    updated = 0
    try:
        params = request.form.to_dict()
        source = params.get("source")
        target = params.get("target")
        taskId = params.get("taskId")

        if "." not in source:
            source = "dataset." + source
        # NOTE(review): request-supplied table name formatted into SQL — confirm validated upstream.
        sql = "select * from {} order by _record_id_".format(source)
        df = read_gp(sql)

        df, results = zscore_outlier.run(df, params)
    except Exception as e:
        return err_formatter(e)
    if df is not None:
        saved = save_gp(target, df)
        updated = update_task(taskId, df, target, True, results)
    if saved and updated:
        # ROBUSTNESS: pop() instead of `del` so a missing key cannot raise KeyError.
        results.pop("data_vis", None)
        return py_to_java(str(results))
    return get_err_msg()


@feature.route('/anomaly_knn', methods=['POST'])
def anomaly_knn():
    """Detect outliers with k-means clustering; save the flagged table.

    Form params: source, target, taskId plus options for kmeans_outlier.run.
    The "data_vis" entry produced by the algorithm is passed to update_task
    but stripped from the HTTP response.
    """
    # BUG FIX: saved/updated were unbound when df is None (NameError below).
    saved = 0
    updated = 0
    try:
        params = request.form.to_dict()
        source = params.get("source")
        target = params.get("target")
        taskId = params.get("taskId")

        if "." not in source:
            source = "dataset." + source
        # NOTE(review): request-supplied table name formatted into SQL — confirm validated upstream.
        sql = "select * from {} order by _record_id_".format(source)
        df = read_gp(sql)

        df, results = kmeans_outlier.run(df, params)
    except Exception as e:
        return err_formatter(e)
    if df is not None:
        saved = save_gp(target, df)
        updated = update_task(taskId, df, target, True, results)
    if saved and updated:
        # ROBUSTNESS: pop() instead of `del` so a missing key cannot raise KeyError.
        results.pop("data_vis", None)
        return py_to_java(str(results))
    return get_err_msg()


@feature.route('/timeseries_decompose', methods=['POST'])
def timeseries_decompose():
    """Decompose a time series (via time_series_decompose.run) and save it.

    Form params: source, target, taskId plus decomposition options.
    Returns a plain SUCCESS status string on success (no results payload).
    """
    # BUG FIX: saved/updated were unbound when df is None (NameError below).
    saved = 0
    updated = 0
    try:
        params = request.form.to_dict()
        source = params.get("source")
        target = params.get("target")
        taskId = params.get("taskId")

        if "." not in source:
            source = "dataset." + source
        # NOTE(review): request-supplied table name formatted into SQL — confirm validated upstream.
        sql = "select * from {} order by _record_id_".format(source)
        df = read_gp(sql)

        df = time_series_decompose.run(df, params)
    except Exception as e:
        return err_formatter(e)
    if df is not None:
        saved = save_gp(target, df)
        updated = update_task(taskId, df, target, True)
    if saved and updated:
        return str({"status": "SUCCESS"})
    return get_err_msg()


# 通用插补 (general imputation): period fill, mean fill, mode fill, median fill, linear interpolation
@feature.route('/imputation_stat', methods=['POST'])
def imputation_stat():
    """Fill missing values with period / mean / mode / median / linear methods.

    Form params: source, target, taskId, method ('stat_fill' | 'linear_fill' |
    'period_fill'), stat_method ('mean_fill' | 'mode_fill' | 'median_fill'),
    cols (list literal of column names), period (int, for period_fill).
    Accumulates advisory messages (columns with no nulls, non-normal or
    non-stationary columns) into result["message"], and puts the changed-cell
    highlight map into result["highlight"].
    """
    result = {}
    messages = []
    # BUG FIX: df_fill/saved/updated were previously unbound on some paths
    # (unknown method, df_fill is None), raising NameError instead of
    # returning a proper error response.
    df_fill = None
    saved = 0
    updated = 0
    try:
        params = request.form
        source = params.get("source")
        target = params.get("target")
        taskId = params.get("taskId")
        method = params.get("method")
        # SECURITY FIX: eval() on request data allows arbitrary code execution;
        # ast.literal_eval only accepts Python literals such as the expected list.
        cols = ast.literal_eval(params.get("cols"))
        if len(cols) == 0:
            result["message"] = "请选择特征列"
            return py_to_java(str(result))
        period = params.get("period")
        stat_method = params.get("stat_method")
        if "." not in source:
            source = "dataset." + source
        # NOTE(review): request-supplied table name formatted into SQL — confirm validated upstream.
        sql = "select * from {} ORDER BY _record_id_".format(source)
        df = read_gp(sql)
        not_null_cols = not_null_col(df, cols)
        if not_null_cols:
            messages.append("以下特征列不含有缺失值： " + ",\t".join(not_null_cols))
        if method == 'stat_fill':
            if stat_method == 'mean_fill':
                # Mean fill: warn when columns fail the normality test...
                norm_non_steady = norm_test(df, cols)
                if norm_non_steady:
                    messages.append("以下特征列不符合正态分布： " + ",\t".join(norm_non_steady))
                # ...or the ADF stationarity test.
                adf_non_steady = adf_tests(df, cols)
                if adf_non_steady:
                    messages.append("以下特征列不平稳： " + ",\t".join(adf_non_steady))
                df_fill = mean_interpolate(df, cols)
            elif stat_method == 'mode_fill':
                df_fill = most_frequent_interpolate(df, cols)
            elif stat_method == 'median_fill':
                df_fill = median_interpolate(df, cols)
        elif method == 'linear_fill':
            df_fill = liner_interpolate(df, cols)
        elif method == 'period_fill':
            # Period fill assumes stationary series; warn on ADF failures.
            adf_non_steady = adf_tests(df, cols)
            if adf_non_steady:
                messages.append("以下特征列不平稳： " + ",\t".join(adf_non_steady))
            df_fill = period_fill(df, cols, int(period))
            # period_fill signals an error by returning a message string.
            # (isinstance replaces `type(x) == str`; dead `else: df_fill = df_fill` removed.)
            if isinstance(df_fill, str):
                result["message"] = df_fill
                return py_to_java(str(result))
        if messages:
            result["message"] = "; ".join(messages)
    except Exception as e:
        return err_formatter(e)
    if df_fill is not None:
        saved = save_gp(target, df_fill)
        # Map of cells that were filled, for front-end highlighting.
        highlight_data = fillsDic(df, df_fill, cols)
        result['highlight'] = highlight_data
        updated = update_task(taskId, df, target, True, result)
    if saved and updated:
        return py_to_java(str(result))
    return get_err_msg()


# 多重插补 (multiple imputation via MICE)
@feature.route('/imputation_multi', methods=['POST'])
def imputation_multi():
    """Fill missing values with multiple imputation (MICE).

    Form params: source, target, taskId, method ('mice_fill'),
    cols (list literal of column names). Requires more than one selected
    column with at least one containing nulls (checkNull). Puts the
    changed-cell highlight map into result["highlight"].
    """
    result = {}
    # BUG FIX: df_fill/saved/updated were previously unbound when method was
    # not 'mice_fill' or df_fill stayed None, raising NameError below.
    df_fill = None
    saved = 0
    updated = 0
    try:
        params = request.form
        source = params.get("source")
        target = params.get("target")
        taskId = params.get("taskId")
        # SECURITY FIX: eval() on request data allows arbitrary code execution;
        # ast.literal_eval only accepts Python literals such as the expected list.
        cols = ast.literal_eval(params.get("cols"))
        if len(cols) == 0:
            result["message"] = "请选择特征列"
            return py_to_java(str(result))
        method = params.get("method")
        if "." not in source:
            source = "dataset." + source
        # NOTE(review): request-supplied table name formatted into SQL — confirm validated upstream.
        sql = "select * from {} ORDER BY _record_id_".format(source)
        df = read_gp(sql)
        if method == 'mice_fill':
            checked = checkNull(df, cols)
            if checked and len(cols) > 1:
                df_fill = mice_fill(df, cols)
            else:
                result["message"] = "请选择至少含有一列缺失值的多个特征列"
                return py_to_java(str(result))
    except Exception as e:
        return err_formatter(e)
    if df_fill is not None:
        saved = save_gp(target, df_fill)
        # Map of cells that were filled, for front-end highlighting.
        highlight_data = fillsDic(df, df_fill, cols)
        result['highlight'] = highlight_data
        updated = update_task(taskId, df, target, True, result)
    if saved and updated:
        return py_to_java(str(result))
    return get_err_msg()

# # ADF test接口
# @feature.route('/adf_tests', methods=['POST'])
# def adf_tests_function():
#     params = request.form
#     source = params.get("tableName")
#     cols = eval(params.get("cols"))
#     if "." not in source:
#         source = "dataset." + source
#     sql = "select * from {} ORDER BY _record_id_".format(source)
#     try:
#         df = read_gp(sql)
#         non_steady_col = adf_tests(df, cols)
#     except Exception as e:
#         return str(e)
#     # 返回不平稳字段
#     if non_steady_col:
#         return py_to_java(str(non_steady_col))
#     else:
#         return "所选特征列中无不平稳特征列"

#
# # 正态分布接口
# @feature.route('/norm_test', methods=['POST'])
# def normal_test_function():
#     params = request.form
#     source = params.get("tableName")
#     cols = eval(params.get("cols"))
#     if "." not in source:
#         source = "dataset." + source
#     sql = "select * from {} ORDER BY _record_id_".format(source)
#     try:
#         df = read_gp(sql)
#         non_steady_col = norm_test(df, cols)
#     except Exception as e:
#         return str(e)
#     # 返回不符合正太分布字段
#     if non_steady_col:
#         return py_to_java(str(non_steady_col))
#     else:
#         return "所选特征列中无不符合正太分布特征列"

@feature.route('/timeseries_shift', methods=['POST'])
def shift():
    """Build lag/lead (shifted) features for a time series and save them.

    Form params: source, target, taskId, cols (list literal), n_in, n_out,
    drop (ints) — passed to timeseries_shift.run. Returns its results payload
    directly on success.
    """
    try:
        params = request.form.to_dict()
        # FIX: log via logging (consistent with the other routes) instead of print.
        logging.info(params)
        n_in = int(params["n_in"])
        n_out = int(params["n_out"])
        # SECURITY FIX: eval() on request data allows arbitrary code execution;
        # ast.literal_eval only accepts Python literals such as the expected list.
        columns = ast.literal_eval(params["cols"])
        drop = int(params["drop"])
        source = params.get("source")
        target = params.get("target")
        taskId = params.get("taskId")
        if "." not in source:
            source = "dataset." + source
        # NOTE(review): request-supplied table name formatted into SQL — confirm validated upstream.
        sql = "select * from {} order by _record_id_".format(source)
        df = read_gp(sql)

        df, results = timeseries_shift.run(df, taskId, columns, n_in, n_out, drop)
        logging.info(results)
    except Exception as e:
        return err_formatter(e)
    saved = 0
    updated = 0
    if df is not None:
        try:
            saved = save_gp(target, df)
            updated = update_task(taskId, df, target, True, results)
        except Exception as e:
            logging.error("error occurs in saving results")
            logging.error(e)
            return err_formatter(e)
    # BUG FIX: the success/failure decision used to sit in a `finally` block,
    # whose `return` unconditionally overrode the `return err_formatter(e)`
    # above and would swallow any in-flight exception; the df-is-None path
    # also had no return (implicit None response).
    if saved and updated:
        return results
    return get_err_msg()
